#pragma once
#include <ATen/Config.h>
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <c10/core/ScalarType.h>

// Defines the accumulation type for a scalar type.
// Example:
//   using accscalar_t = acc_type<scalar_t, /*is_cuda*/true>;
//
// Accumulation types are an important concept in numeric computing
// because you frequently want to perform intermediate computations
// at a higher precision than the input and output precision, to avoid
// compounding internal rounding errors.  Accumulation is the most
// well-known intermediate computation (it is of great importance for
// sum reduction and matrix multiply, for example), but in PyTorch
// acc_type ends up getting used for all sorts of other intermediate
// computations, so it perhaps would be more accurately (ahem) called an
// "accurate" type.  acc_type is especially important for reduced
// precision operations like float16 and bfloat16, where relatively
// benign looking inputs can easily end up overflowing/underflowing.
//
// acc_type is parametrized by whether or not you are running on CUDA
// or not, because on CUDA double precision operations are expensive
// and so by default, we don't actually want to use double as an
// acc_type on CUDA.  A lot of things are typed out below, but
// basically, the table is generated by a few rules:
//
//  If bool:
//      Use 'bool' as acc_type.
//  If floating point:
//      If CUDA, use 'float' as acc_type (unless scalar_t is double),
//      otherwise (CPU) use 'double'
//  If integral:
//      Use 'int64_t' as acc_type
//
// You're not forced to use this template; if you happen to know
// something specific about your use case, you can specify your own
// desired behavior.  This template, however, will give you a reasonable
// default that will work for all dtypes supported in PyTorch.

#if defined(__CUDACC__)
#include <cuda.h>
#include <cuda_fp16.h>
#elif defined(__HIPCC__)
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#endif

namespace at {

template <typename T, bool is_cuda>
struct AccumulateType { };

#if defined(__CUDACC__) || defined(__HIPCC__)
template <> struct AccumulateType<half, true> { using type = float; };
#endif
template <> struct AccumulateType<BFloat16, true> {using type = float; };
template <> struct AccumulateType<Half, true> { using type = float; };
template <> struct AccumulateType<float, true> { using type = float; };
template <> struct AccumulateType<double, true> { using type = double; };
template <> struct AccumulateType<int8_t, true> { using type = int64_t; };
template <> struct AccumulateType<uint8_t, true> { using type = int64_t; };
template <> struct AccumulateType<char, true> { using type = int64_t; };
template <> struct AccumulateType<int16_t, true> { using type = int64_t; };
template <> struct AccumulateType<int32_t, true> { using type = int64_t; };
template <> struct AccumulateType<int64_t, true> { using type = int64_t; };
template <> struct AccumulateType<bool, true> {using type = bool; };
template <> struct AccumulateType<Half, false> { using type = float; };
template <> struct AccumulateType<BFloat16, false> { using type = float; };
template <> struct AccumulateType<c10::complex<float>, false> { using type = c10::complex<double>; };
template <> struct AccumulateType<c10::complex<double>, false> { using type = c10::complex<double>; };
template <> struct AccumulateType<c10::complex<float>, true> { using type = c10::complex<float>; };
template <> struct AccumulateType<c10::complex<double>, true> { using type = c10::complex<double>; };
template <> struct AccumulateType<float, false> { using type = double; };
template <> struct AccumulateType<double, false> { using type = double; };
template <> struct AccumulateType<int8_t, false> { using type = int64_t; };
template <> struct AccumulateType<uint8_t, false> { using type = int64_t; };
template <> struct AccumulateType<char, false> { using type = int64_t; };
template <> struct AccumulateType<int16_t, false> { using type = int64_t; };
template <> struct AccumulateType<int32_t, false> { using type = int64_t; };
template <> struct AccumulateType<int64_t, false> { using type = int64_t; };
template <> struct AccumulateType<bool, false> {using type = bool; };

template<typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda);

}  // namespace at
